import jieba
import numpy as np
import paddle
from paddlenlp.data import Vocab

def convert_example(example, tokenizer, is_test=False):
    """
    Convert one raw example into model-ready numpy inputs.

    Encodes ``example['text']`` with ``tokenizer.encode`` and wraps the
    resulting token ids and the sequence length in int64 numpy arrays.
    Outside test mode, ``example['label']`` is converted and returned too.

    :param example: dict with a ``'text'`` key (and ``'label'`` unless is_test)
    :param tokenizer: object exposing ``encode(text) -> list[int]``
    :param is_test: when True, no label is read or returned
    :return: ``(input_ids, valid_length)`` if is_test,
             else ``(input_ids, valid_length, label)``
    """
    token_ids = tokenizer.encode(example['text'])
    seq_len = np.array(len(token_ids), dtype='int64')
    token_ids = np.array(token_ids, dtype='int64')

    if is_test:
        return token_ids, seq_len

    label = np.array(example['label'], dtype='int64')
    return token_ids, seq_len, label

def create_dataloader(
        dataset,
        trans_fn=None,
        mode='train',
        batch_size=1,
        use_gpu=True,
        pad_token_id=0,
        batchify_fn=None
):
    """
    Create a paddle DataLoader over `dataset`.

    Args:
        dataset(obj:`paddle.io.Dataset`): Dataset instance.
        trans_fn(obj:`callable`, optional, defaults to `None`): Function to
            convert a data sample to input ids, etc. Applied lazily when given.
        mode(obj:`str`, optional, defaults to `train`): If mode is 'train',
            a DistributedBatchSampler with shuffling is used; otherwise a
            plain BatchSampler without shuffling.
        batch_size(obj:`int`, optional, defaults to 1): The sample number of
            a mini-batch.
        use_gpu(obj:`bool`, optional, defaults to `True`): Currently UNUSED;
            kept only for backward compatibility. Device placement is not
            controlled by this function.
        pad_token_id(obj:`int`, optional, defaults to 0): Currently UNUSED;
            kept only for backward compatibility. Any padding must be done
            by `batchify_fn`.
        batchify_fn(obj:`callable`, optional, defaults to `None`): Function to
            generate mini-batch data by merging the sample list; `None` stacks
            each field of the samples along axis 0
            (same as :attr:`np.stack(..., axis=0)`).

    Returns:
        dataloader(obj:`paddle.io.DataLoader`): The dataloader which
            generates batches.
    """
    if trans_fn:
        # NOTE(review): assumes `dataset` exposes `apply(fn, lazy=...)`
        # (older paddlenlp dataset API); newer MapDataset uses `.map` —
        # confirm against the paddlenlp version in use.
        dataset = dataset.apply(trans_fn, lazy=True)

    if mode == 'train':
        # Distributed sampler so multi-card training shards the data;
        # shuffle every epoch only in train mode.
        sampler = paddle.io.DistributedBatchSampler(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=True
        )
    else:
        sampler = paddle.io.BatchSampler(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False
        )

    dataloader = paddle.io.DataLoader(
        dataset,
        batch_sampler=sampler,
        return_list=True,
        collate_fn=batchify_fn
    )

    return dataloader